from Marchines import CART
# help(CART)
if __name__ == "__main__":
    # Compare three predictors on the bike-speed vs. IQ data:
    # a regression tree, a model tree and plain linear regression.
    # Each one is scored by the correlation coefficient between its
    # predictions on the test set and the true test targets (column 1).
    trainMat = CART.mat(CART.loadDataSet('bikeSpeedVsIq_train.txt'))
    testMat = CART.mat(CART.loadDataSet('bikeSpeedVsIq_test.txt'))
    testX = testMat[:, 0]  # feature column used for prediction
    testY = testMat[:, 1]  # target column the predictions are scored against

    def correlation(yHat):
        """Return the off-diagonal entry of corrcoef(yHat, testY) with columns as variables."""
        # rowvar=0: each column is a variable and each row an observation,
        # so the result is a 2x2 matrix and [0, 1] is the cross term.
        return CART.corrcoef(yHat, testY, rowvar=0)[0, 1]

    # Regression tree. ops=(1, 20) is presumably the (error tolerance,
    # minimum samples per split) pre-pruning pair used by createTree.
    myTree1 = CART.createTree(trainMat, ops=(1, 20))
    print(myTree1)
    yHat1 = CART.createForeCast(myTree1, testX)
    print("--------------\n")
    print("回归树:", correlation(yHat1))

    # Model tree: leaves hold linear models (modelLeaf/modelErr) and are
    # evaluated with modelTreeEval when forecasting.
    myTree2 = CART.createTree(trainMat, CART.modelLeaf, CART.modelErr, ops=(1, 20))
    yHat2 = CART.createForeCast(myTree2, testX, CART.modelTreeEval)
    print(myTree2)
    print("模型树:", correlation(yHat2))

    # Plain linear regression. ws[0, 0] is used as the intercept and
    # ws[1, 0] as the slope. The design matrices returned by linearSolve
    # are not needed here.
    ws, _, _ = CART.linearSolve(trainMat)
    print(ws)
    # A vectorized prediction replaces the per-row loop. The result is
    # the same (m, 1) matrix.
    yHat3 = testX * ws[1, 0] + ws[0, 0]
    print("线性回归:", correlation(yHat3))
